import copy, os
from torch import nn
import numpy as np
from State.utils import ObjDict
from tianshou.data import Batch
from Causal.dynamics import Dynamics
from Causal.ac_extractor import ACExtractor
from Causal.ac_infer.Network.network_utils import pytorch_model
import sys
import torch

def regenerate(args, environment, config):
    '''
    Build the extractor and normalizer ac_infer needs from an existing environment.

    Stands in for the regenerate routine inside ac_infer: instead of creating
    an environment, it reuses the one passed in and provides just enough of the
    extractor and normalization for the inference model to be constructed.

    Side effects:
        - appends <sys.path[0]>/Causal/ac_infer to sys.path if it is not there
        - writes first_obj_dim, single_obj_dim, object_dim, all_obj_dim,
          named_first_obj_dim, num_objects, name_idxes and name_idx onto
          args.factor in place

    Args:
        args: ac_infer argument namespace; only args.factor is modified here.
        environment: environment exposing all_names, object_names,
            object_range, object_range_true, object_dynamics and
            object_instanced.
        config: configuration object, forwarded unchanged to ACExtractor.

    Returns:
        (extractor, normalization, ac_norm) where extractor is the ACExtractor,
        normalization is always None (observations are already normalized
        upstream, so no normalizer is handed to the model), and ac_norm is a
        PadNormalizationModule built from the environment's ranges.
    '''
    # ac_infer uses imports relative to its own root, so that root must be importable
    ac_infer_root = os.path.join(sys.path[0], "Causal", "ac_infer")
    if ac_infer_root not in sys.path:
        sys.path.append(ac_infer_root)
    from Causal.ac_infer.Environment.Normalization.pad_norm import PadNormalizationModule

    # since we are already normalizing, we don't send the model a normalizer
    # TODO: actually implement (this component may be unused by the model)
    normalization = None

    extractor = ACExtractor(args, environment, config)
    (
        args.factor.first_obj_dim,
        args.factor.single_obj_dim,
        args.factor.object_dim,
        args.factor.all_obj_dim,
        args.factor.named_first_obj_dim,
    ) = extractor._get_dims()
    args.factor.num_objects = len(environment.all_names)
    args.factor.name_idxes = extractor.name_idxes
    args.factor.name_idx = -1

    # TODO: when encodings are used instead of raw object ranges (encoding_dim > 0),
    # the range and dynamics of every object other than the reserved
    # "Action", "Reward", "Done" keys should be swapped for the encoding ones.
    ac_norm = PadNormalizationModule(
        environment.object_range,
        environment.object_range_true,
        environment.object_dynamics,
        environment.object_instanced,
        environment.object_names,
        extractor.pad_dim,
        extractor.expand_dim,
        all=True,
    )
    return extractor, normalization, ac_norm

class DynamicsAC(Dynamics):
    '''
    Dynamics model that infers interaction graphs with an ac_infer InferenceModel.

    The class converts replay data into the batch layout ac_infer expects
    (see wrap), queries the loaded inference model for interaction masks
    (see __call__) and optionally keeps training that model (see update).
    data.true_graph is only used to build the "trace" entry of the wrapped batch.
    '''
    def __init__(self, env, extractor, norm, config, wdb_run=None):
        '''
        Load the ac_infer model described by config.dynamics.ac.dynamics_config_path.

        Args:
            env: environment; all_names, discrete_actions, action_space and name are read.
            extractor: extractor used to slice flat observations into per-object targets.
            norm: flat observation normalizer; must provide denormalize_obs.
            config: run configuration (dynamics.ac.*, cuda_id, device, load.load_dir, num_factors).
            wdb_run: optional logging run handle, stored but not used yet.
        '''
        super().__init__(env, extractor)
        # initialize necessary components
        self.extractor = extractor
        # ac_infer is imported here, after its root is put on sys.path, in case
        # importing it at module load produces import errors
        # TODO: ac_infer should be in this folder
        ac_infer_root = os.path.join(sys.path[0], "Causal", "ac_infer")
        if ac_infer_root not in sys.path:
            sys.path.append(ac_infer_root)
        from Causal.ac_infer.Model.base_model import InferenceModel
        from Causal.ac_infer.Hyperparam.read_config import read_config
        from Causal.ac_infer.ActualCausal.Updater.update_params import compute_params
        from Causal.ac_infer.ActualCausal.Train.train_model import train_model
        from Causal.ac_infer.ActualCausal.train_loop import pretrain
        from Causal.ac_infer.ActualCausal.Inference.compute_inference import compute_inference
        from Causal.ac_infer.Model.model_utils import load_model

        self.args = read_config(config.dynamics.ac.dynamics_config_path)
        self.args.torch.cuda = torch.cuda.is_available()
        self.args.torch.gpu = config.cuda_id
        # only select a CUDA device when one exists, so CPU-only hosts can
        # still reach the "cpu" fallback used by load_model below
        if self.args.torch.cuda:
            torch.cuda.set_device(config.device)
        self.flat_norm_fn = norm

        # TODO: only supports a single object
        train_names = self.args.inter.train_names
        self.target_object = train_names[0] if len(train_names) > 0 else ""
        self.target_object_idx = env.all_names.index(self.target_object) if len(self.target_object) > 0 else -1

        # the extractor built here is the ac_infer-side one and is only handed
        # to the model; self.extractor stays the one passed to the constructor
        ac_extractor, normalization, ac_norm = regenerate(self.args, env, config)
        self.ac_norm_fn = ac_norm
        self.ac_model = InferenceModel(self.args, ac_extractor, normalization, env)
        if len(self.target_object) > 0:
            self.ac_model.set_target_name(self.target_object)
        # a load directory in the run config overrides the one from the ac_infer config
        if len(config.load.load_dir) > 0:
            self.args.record.load_dir = config.load.load_dir
        self.ac_model = load_model(
            self.ac_model,
            self.args.record.load_dir,
            device=self.args.torch.gpu if self.args.torch.cuda else "cpu",
        )

        self.compute_params = compute_params
        self.params = self.compute_params(0, self.args, None, result=None)
        self.log_batch = [] # trace, valid could be used here
        self.pretrain_fn = pretrain
        self.train_model_fn = train_model
        self.wdb_run = wdb_run # TODO: attach this to the update code
        self.update_counter = 0
        # result of the most recent train_model call, fed to the next compute_params
        self._last_result = None
        self.threshold = config.dynamics.ac.graph_threshold
        # identity over the factors with a zero column prepended at index 0,
        # i.e. shape (num_factors, num_factors + 1)
        self.passive_graph = np.eye(config.num_factors)
        self.passive_graph = np.concatenate([np.array([[0] * config.num_factors]).T, self.passive_graph], axis=-1)
        self.discrete_actions = env.discrete_actions
        self.num_actions = env.action_space.n if self.discrete_actions else env.action_space.shape
        self.num_factors = config.num_factors
        self.mask_mode = config.dynamics.ac.mask_mode
        self.compute_inference = compute_inference
        self.train_dynamics = self.args.train.train
        self.use_neg_reward = env.name == "AirHockey"

    def reset(self):
        '''Restart the update schedule: zero the counter and forget the last training result.'''
        self.update_counter = 0
        self._last_result = None

    def merge_ac_observation(self, observation, batch, assign_goal=None):
        '''
        Build the flat per-object observation ac_infer consumes.

        The per-object targets sliced from observation are surrounded by an
        action row in front and reward/done rows behind, a one-hot object id of
        width num_factors + 3 is appended to every row, and the result is
        flattened per sample.

        Args:
            observation: flat observations, passed to extractor.slice_targets.
            batch: source of act, rew and done; len(batch) is the batch size.
            assign_goal: optional goal values written into target slot 2.

        Returns:
            Array of shape (len(batch), -1).
        '''
        target = self.extractor.slice_targets(observation)

        # assigns the goal, if available (slot 2 is presumably the goal object — confirm)
        if assign_goal is not None:
            target[:, 2, :assign_goal.shape[-1]] = assign_goal

        # adds action to the observation
        if self.discrete_actions:
            act_obs = np.eye(self.num_actions)[batch.act]
        else:
            act_obs = batch.act
        # pad relative to the action row actually being inserted: for discrete
        # actions that is the one-hot width, not the last dim of batch.act
        pad_width = self.extractor.longest - act_obs.shape[-1]
        if pad_width > 0:
            act_obs = np.pad(act_obs, [(0, 0), (0, pad_width)])

        # adds reward and done to the observation, each in the first slot of its own row
        rewdone = np.zeros((len(batch), 2, self.extractor.longest))
        rewdone[:, 0, 0] = batch.rew - 1 * int(self.use_neg_reward)
        rewdone[:, 1, 0] = batch.done

        merged_observation = np.concatenate([np.expand_dims(act_obs, 1), target, rewdone], axis=1)

        # adding in the IDs, TODO: does not handle objects with multiple classes
        # TODO: IDs are probably wrong if multiple instances of the same class
        object_ids = np.tile(np.eye(self.num_factors + 3), (len(batch), 1, 1))

        merged_observation_id = np.concatenate([merged_observation, object_ids], axis=-1)
        # TODO: without a target object this might actually need to return all targets concatenated
        return merged_observation_id.reshape(len(batch), -1)

    def wrap(self, data):
        '''
        Convert replay data into a Batch with the keys ac_infer reads.

        Produces obs, target, target_next, target_diff, valid, trace and done.
        data.obs may be a Batch (observation and desired_goal are read from it)
        or a plain array.

        TODO: actually implement (valid is all ones for now)
        '''
        # extend the true graph with zero columns for reward/done and with
        # rows for action, reward and done
        full_trace = data.true_graph
        graph_width = self.num_factors + 3
        action_row = np.zeros((len(data), 1, graph_width))
        # NOTE(review): this fills the whole action row of batch element 0 only;
        # it looks like a column index was intended — confirm before changing
        action_row[0] = 1
        rew_row = np.zeros((len(data), 1, graph_width))
        done_row = np.zeros((len(data), 1, graph_width))
        right_columns = np.zeros((len(data), self.num_factors, 2))
        full_trace = np.concatenate([full_trace, right_columns], axis=-1)
        full_trace = np.concatenate([action_row, full_trace, rew_row, done_row], axis=1)

        if isinstance(data.obs, Batch): # it should be a batch
            obs, obs_next = data.obs.observation, data.obs_next.observation
            goal = data.obs.desired_goal
        else:
            obs, obs_next = data.obs, data.obs_next
            goal = None

        # might need to append object ids to observation, pad length
        denorm_obs = self.flat_norm_fn.denormalize_obs(obs)
        denorm_next_obs = self.flat_norm_fn.denormalize_obs(obs_next)
        merged_observation = self.merge_ac_observation(obs, data, goal)

        slice_kwargs = dict(append_act=True, append_rew_done=True, flatten=True)
        return Batch(
            obs=merged_observation,
            target=self.extractor.slice_targets(obs, **slice_kwargs),
            target_next=self.extractor.slice_targets(obs_next, **slice_kwargs),
            # the dynamics normalizer is applied to the denormalized difference
            # on both input paths
            target_diff=self.ac_norm_fn(
                self.extractor.slice_targets(denorm_next_obs - denorm_obs, **slice_kwargs),
                name="all",
                form="dyn",
            ),
            # might need to actually return real valid vectors in the future
            valid=np.ones((len(data), self.args.factor.num_objects)),
            trace=full_trace,
            done=data.done,
        )

    def __call__(self, data):
        '''
        Infer the interaction graph for every transition in data.

        Returns:
            With a target object configured: a boolean array built by tiling
            passive_graph over the batch and overwriting the target's row with
            the inferred mask for transitions whose done flag is zero.
            Otherwise: the model's inter_masks unchanged.
        '''
        batch = self.wrap(data)
        single_target_mode = len(self.target_object) != 0
        if self.mask_mode == "mask_logits":
            # available options: batch, given_mask, infer_type, additional=[], grad_settings=[], log_batch=[], keep_invalid=False, keep_all=False
            form = "mask" if single_target_mode else "all_mask"
            result = self.ac_model.infer(batch, batch.valid, infer_type=form)[form]

            # round to fixed values according to the threshold (in place)
            result.mask_logits[result.mask_logits >= self.threshold] = 1
            result.mask_logits[result.mask_logits < self.threshold] = 0
            result.inter_masks = result.mask_logits
        else:
            result = self.compute_inference(self.args, self.params, self.ac_model, batch, [self.mask_mode], keep_all=True, perform_analysis=-1)[self.mask_mode]
            if single_target_mode:
                result = result[self.target_object]

        # if single target mode, modify the desired index, otherwise just return as is
        if not single_target_mode:
            return result.inter_masks

        full_graph = np.tile(self.passive_graph, (len(batch),) + (1,) * len(self.passive_graph.shape))
        done = batch.done if len(batch.done.shape) == 1 else batch.done[:, 0]
        # indices of transitions that are not done; done ones keep the passive row
        not_done = (1 - done).nonzero()
        # the last two mask entries are dropped (presumably the reward/done slots — confirm)
        full_graph[not_done[0], self.target_object_idx - 1] = pytorch_model.unwrap(result.inter_masks[not_done[0], :-2])
        full_graph[:, self.target_object_idx - 1, self.target_object_idx] = 1 # ignore the passive edge
        # full_graph[:,self.target_object_idx, 0] = 0 # TODO: ignore the action edge?
        return full_graph.astype(bool)

    def pretrain(self, batch_size, buffer):
        '''Run ac_infer's pretraining on buffer; batch_size is accepted but not used.'''
        self.pretrain_fn(self.args, self.ac_model, buffer, wrap=self.wrap)

    def update(self, batch_size, buffer):
        '''
        Run one ac_infer training step on buffer when training is enabled.

        Returns:
            ObjDict with "bin_error" from the training result, or an empty
            ObjDict when args.train.train is false. batch_size is not used.
        '''
        # TODO: instead of wrapping the buffer, which is probably expensive
        if not self.train_dynamics:
            return ObjDict(dict())

        # compute_params signature:
        # i, args, buffer, pretrain=False, result=None, params = None, weight_binaries=None, passive_weight_binaries=None
        # result is the previous step's training result (None on the first step)
        self.params = self.compute_params(
            self.update_counter,
            self.args,
            buffer,
            pretrain=False,
            result=self._last_result,
            params=self.params,
        )
        # train_model signature:
        # i, args, params, model, train_buffer, log_batch=[], wrap_function=None, intermediate_logger=None
        result = self.train_model_fn(
            self.update_counter,
            self.args,
            self.params,
            self.ac_model,
            buffer,
            log_batch=self.log_batch,
            wrap_function=self.wrap,
        )
        self._last_result = result
        self.update_counter += 1 # since it's not passed in
        return ObjDict({"bin_error": result.bin_error}) # TODO: use the wdb logger here, and return useful values

    def compute_weight(self, data, dynamics, graph, true_graph, proximity):
        '''Return a uniform weight of one for every element of data.'''
        return np.ones((len(data), ))
